
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import json
import argparse
import logging
import os
import re
from tqdm import tqdm
from openai import OpenAI


# Configure the root logger once at import time: INFO and above, timestamped lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)


def build_cot_prompts(instruction, output):
    """Build the three judge prompts used to score a chain-of-thought sample.

    Args:
        instruction: The problem statement; interpolated after ``Problem:``.
        output: The answer including its chain of thought; interpolated after
            ``Answer with Chain-of-Thought:``.

    Returns:
        A 3-tuple of prompt strings ``(reasoning_verbosity, cognitive_difficulty,
        logical_correctness)``. The first two ask for an integer 0-9 and the
        last for 1/0, each wrapped in ``<score></score>`` tags.
    """
    rv_prompt_template = (
        "You are an expert judge tasked with evaluating the Reasoning Verbosity of a Chain-of-Thought (CoT) "
        "for a given problem and its answer. Reasoning Verbosity Evaluation Focus: Assess how well the CoT’s "
        "length and step complexity match the problem’s inherent difficulty. An optimal chain is neither "
        "missing essential steps nor padded with needless digressions. A simple question should be solved "
        "with a brief, direct chain; a challenging one may justifiably require a longer path with reflection "
        "and error-checking. Scoring Guidelines (0-9):\n"
        "0-1 Minimal verbosity, straightforward expression with little to no elaboration.\n"
        "2-3 Clear and concise reasoning with necessary explanations.\n"
        "4-5 Moderate verbosity with detailed explanations and thorough reasoning.\n"
        "6-7 Extensive verbosity with comprehensive justification and exploration of complex connections.\n"
        "8-9 High verbosity with deep, exhaustive exploration of reasoning; involves extensive elaboration, nested justifications, "
        "and consideration of counterarguments or alternative perspectives.\n"
        "Given Problem, Answer with Chain-of-Thought, you will:\n"
        "1. Analyze the Reasoning Verbosity\n"
        "2. Determine score using the above criteria\n"
        "3. Output ONLY the integer score (0-9), place your score in <score></score>\n"
        f"Problem: {instruction}\n"
        f"Answer with Chain-of-Thought: {output}"
    )
    cd_prompt_template = (
        "You are an expert judge assessing the Cognitive Difficulty of a Chain-of-Thought (CoT) "
        "for a given problem and its answer. Cognitive Difficulty Evaluation Focus: The level of "
        "reasoning competence required for a model to follow and reproduce the chain faithfully. "
        "Judge the reasoning approach, techniques, and overall difficulty. Higher scores correspond "
        "to more advanced concepts, abstractions, or multi-layer reasoning patterns. "
        "Scoring Guidelines (0-9):\n"
        "0-1 Elementary facts or a single trivial operation.\n"
        "2-3 Multi-step arithmetic, explicit enumeration, basic rule chaining.\n"
        "4-5 Early-undergraduate logic/algebra; one non-obvious insight.\n"
        "6-7 Advanced undergraduate techniques (determinants, dynamic programming, layered code reasoning, etc).\n"
        "8-9 Graduate-level abstraction, nested proofs, intricate algorithmic analysis.\n"
        "Given Problem, Answer with Chain-of-Thought, you will:\n"
        "1. Analyze the Cognitive Difficulty\n"
        "2. Determine score using the above criteria\n"
        "3. Output ONLY the integer score (0-9), place your score in <score></score>\n"
        f"Problem: {instruction}\n"
        f"Answer with Chain-of-Thought: {output}"
    )
    lc_prompt_template = (
        "You are a rigorous logical validator analyzing problem-solving components. "
        "Your task is to separately assess the validity of the reasoning process and final solution. "
        "Given Problem, Answer with Chain-of-Thought, you will:\n"
        "1. Verify stepwise logical coherence and soundness\n"
        "2. Confirm all critical problem constraints are properly addressed\n"
        "3. Check for self-contradictions or unsupported leaps in logic\n"
        "4. Verify the process can actually derive the proposed solution\n"
        "5. Output ONLY the 1/0 answer (1 for true, 0 for false) for logical correctness, place your answer in <score></score>\n"
        f"Problem: {instruction}\n"
        f"Answer with Chain-of-Thought: {output}"
    )
    return rv_prompt_template, cd_prompt_template, lc_prompt_template


def build_instruct_prompts(instruction, output):
    """Build the four judge prompts used to score an instruction-following sample.

    Args:
        instruction: The user instruction; interpolated after ``Instruction:``.
        output: The model response; interpolated after ``Response:``.

    Returns:
        A 4-tuple of prompt strings ``(informativeness, helpfulness,
        generalization, correctness)``. The first three ask for an integer 0-9
        and the last for 1/0, each wrapped in ``<score></score>`` tags.
    """
    # Every prompt ends with the sample under review.
    payload = f"Instruction: {instruction}\nResponse: {output}"

    def graded(rubric, focus):
        # Append the shared 0-9 grading steps, then the sample itself.
        return (
            rubric
            + "Given Instruction and Model Response, you will:\n"
            + f"1. Analyze the {focus}\n"
            + "2. Determine a score using the above criteria\n"
            + "3. Output ONLY the integer score (0-9), place your score in <score></score>\n"
            + payload
        )

    informativeness = graded(
        "You are an expert judge tasked with evaluating the Informativeness of a response generated by an instruction-following model "
        "for a given user instruction. Informativeness Evaluation Focus: Assess how thoroughly and accurately the response addresses "
        "the user’s instruction, providing relevant details, facts, and explanations without omissions or irrelevant additions. "
        "An informative response fully satisfies the query with meaningful content, whereas a less informative one may be vague, "
        "incomplete, or superficial. Scoring Guidelines (0-9):\n"
        "0-1 Very low informativeness; the response is irrelevant or nearly empty.\n"
        "2-3 Low informativeness; addresses the instruction minimally with significant missing information.\n"
        "4-5 Moderate informativeness; covers some key points but lacks depth or completeness.\n"
        "6-7 High informativeness; provides detailed and mostly comprehensive information relevant to the instruction.\n"
        "8-9 Exceptional informativeness; thoroughly and accurately covers all relevant aspects with rich and precise details.\n",
        "Informativeness of the response",
    )

    helpfulness = graded(
        "You are an expert judge tasked with evaluating the Helpfulness of a response generated by an instruction-following model "
        "for a given user instruction. Helpfulness Evaluation Focus: Assess how well the response assists the user in accomplishing  "
        "their goal, providing clear, actionable, and relevant information or guidance. A helpful response should be easy "
        "to understand and effectively address the user’s needs without unnecessary confusion or missing key details.\n"
        "Scoring Guidelines (0-9):\n"
        "0-1 Not helpful; response is irrelevant, confusing, or fails to address the instruction.\n"
        "2-3 Slightly helpful; responds partially but lacks clarity or important elements.\n"
        "4-5 Moderately helpful; response addresses the instruction but may be incomplete or somewhat unclear.\n"
        "6-7 Mostly helpful; provides clear and relevant information that adequately assists the user.\n"
        "8-9 Extremely helpful; offers comprehensive, clear, and precise guidance or information that fully satisfies the user’s instruction.\n",
        "Helpfulness of the response",
    )

    generalization = graded(
        "You are an expert judge tasked with evaluating the Potential for Generalization of a response generated by an "
        "instruction-following model to similar but unseen tasks. Generalization Evaluation Focus: Assess how well the response "
        "demonstrates understanding and reasoning that can be effectively adapted or transferred to other related instructions or "
        "problems beyond the specific input. A response with high generalization ability "
        "captures underlying principles or strategies rather than relying on shallow, task-specific heuristics.\n"
        "Scoring Guidelines (0-9):\n"
        "0-1 Very poor generalization; response is overly specific, rigid, or fails to show adaptable reasoning.\n"
        "2-3 Limited generalization; response applies partly to related tasks but is mostly narrow or shallow.\n"
        "4-5 Moderate generalization; response reflects some transferable understanding but may lack depth or clarity.\n"
        "6-7 Strong generalization; response shows clear reasoning patterns or concepts that can extend to similar tasks.\n"
        "8-9 Exceptional generalization; response exhibits deep, abstract, and flexible comprehension applicable across a broad range of related instructions.\n",
        "Potential for Generalization to Similar Tasks",
    )

    # Correctness is a binary verdict, so it skips the 0-9 grading steps.
    correctness = (
        "You are a meticulous correctness evaluator tasked with assessing whether the response to a user instruction "
        "is factually accurate and logically sound.\n"
        "Your evaluation should determine:\n"
        "1. Whether the response correctly addresses the instruction\n"
        "2. Whether any factual claims or data are accurate\n"
        "3. Whether the reasoning, if present, is logically valid and free of errors\n"
        "4. Whether the final answer is consistent with the evidence or instructions provided\n"
        "You will:\n"
        "Output ONLY '1' if the response is correct and accurate, or '0' if it contains factual errors, logical flaws, "
        "or fails to correctly address the instruction.\n"
        "Place your answer in <score></score> tags.\n"
        + payload
    )

    return informativeness, helpfulness, generalization, correctness



def extract_score(text):
    """Return the integer inside the first ``<score>...</score>`` tag of *text*.

    Whitespace around the number is tolerated, e.g. ``<score> 7 </score>`` or
    a number on its own line between the tags.

    Returns:
        The parsed integer, or -1 when *text* is not a string (e.g. ``None``
        content from the API) or contains no tag holding an integer. Callers
        treat -1 as "no usable score".
    """
    if not isinstance(text, str):
        return -1
    match = re.search(r"<score>\s*(\d+)\s*</score>", text)
    return int(match.group(1)) if match else -1


def read_json_fields(filename):
    """Load and return the JSON document stored at *filename*.

    The file is decoded as UTF-8, the JSON standard encoding, so non-ASCII
    data loads the same regardless of the platform's locale.

    Returns:
        The parsed JSON value. On failure the error, including the file name,
        is logged and ``None`` is returned; callers must check for ``None``.
    """
    try:
        with open(filename, 'r', encoding='utf-8') as file:
            return json.load(file)
    except FileNotFoundError:
        logging.error(f"The file was not found: {filename}")
    except json.JSONDecodeError as e:
        logging.error(f"There was an error decoding the JSON file {filename}: {e}")
    except Exception as e:
        logging.error(f"An error occurred while reading {filename}: {e}")
    return None


def write_data_to_json_file(data, file_path):
    """Write *data* to *file_path* as pretty-printed (indent=4) UTF-8 JSON.

    Missing parent directories are created first. Non-ASCII text is written
    verbatim (``ensure_ascii=False``), so the file is opened explicitly as
    UTF-8 rather than in the platform's default encoding, which may not be
    able to represent it.

    Errors are logged with a traceback and not re-raised; this keeps the
    original best-effort contract.
    """
    try:
        parent = os.path.dirname(file_path)
        if parent:  # empty for bare file names in the current directory
            os.makedirs(parent, exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
        logging.info(f"Data successfully written to {file_path}")
    except Exception as e:
        logging.exception(f"An error occurred while writing {file_path}: {e}")

        
def generate_teacher_response_api(data_list, config, is_cot_model):
    """Score every sample with the remote judge model and write the results to disk.

    For each sample (a dict with ``"instruction"`` and ``"output"`` keys), the
    prompts from ``build_cot_prompts`` (when *is_cot_model*) or
    ``build_instruct_prompts`` are sent one request per prompt. The integer in
    each reply's ``<score>`` tag is recorded. The binary dimension
    (``logical_correctness`` / ``correctness``) is stored as a bool that is
    True only for a score of exactly 1.

    A request that fails with an OpenAI SDK error is logged and scored -1, the
    same sentinel ``extract_score`` uses for an unparseable reply. This way one
    bad call does not abort the run and discard every result scored so far.

    Args:
        data_list: Iterable of sample dicts.
        config: Parsed config. Reads ``inference.api_key``,
            ``inference.base_url``, ``inference.max_new_tokens`` and
            ``dataset.output_path``.
        is_cot_model: Use the chain-of-thought rubric instead of the
            instruction-following rubric.
    """
    from openai import OpenAIError  # local import keeps the file-level imports unchanged

    client = OpenAI(
        api_key=config["inference"]["api_key"],
        base_url=config["inference"]["base_url"],
    )
    # Use the first model the endpoint lists as the judge.
    models = client.models.list()
    model = models.data[0].id
    logging.info(model)
    max_tokens = config["inference"]["max_new_tokens"]

    def generate_score(prompt):
        # Send one single-turn chat request and parse its <score> tag; -1 on failure.
        try:
            completion = client.chat.completions.create(
                messages=[{'role': 'user', 'content': prompt}],
                model=model,
                max_completion_tokens=max_tokens,
            )
        except OpenAIError as e:
            logging.error(f"Judge request failed: {e}")
            return -1
        # message.content is Optional in the SDK; None would otherwise crash the regex.
        return extract_score(completion.choices[0].message.content or "")

    outcomes = []
    for sample in tqdm(data_list, desc="Call remote model and generating responses"):
        instruction = sample["instruction"]
        output = sample["output"]
        record = {'instruction': instruction, 'output': output}

        if is_cot_model:
            rv_prompt, cd_prompt, lc_prompt = build_cot_prompts(instruction, output)
            record["reasoning_verbosity"] = generate_score(rv_prompt)
            record["cognitive_difficulty"] = generate_score(cd_prompt)
            # Anything other than an explicit 1 (including -1 = no score) is False.
            record["logical_correctness"] = generate_score(lc_prompt) == 1
        else:
            (informativeness_prompt, helpfulness_prompt,
             generalization_prompt, correctness_prompt) = build_instruct_prompts(instruction, output)
            record["informativeness"] = generate_score(informativeness_prompt)
            record["helpfulness"] = generate_score(helpfulness_prompt)
            record["generalization"] = generate_score(generalization_prompt)
            # Anything other than an explicit 1 (including -1 = no score) is False.
            record["correctness"] = generate_score(correctness_prompt) == 1

        outcomes.append(record)

    write_data_to_json_file(outcomes, config["dataset"]["output_path"])


def infer_with_teacher_model(config):
    """Load the input dataset and score it with the teacher (judge) model.

    The chain-of-thought rubric is used when ``config["job_type"]`` contains
    the substring ``"cot"``; otherwise the instruction-following rubric is used.
    Input is read from ``config["dataset"]["input_path"]``.

    Raises:
        ValueError: If the input dataset could not be loaded.
    """
    logging.info('Generating distillation data from the teacher model!')
    input_path = config["dataset"]["input_path"]
    data_list = read_json_fields(input_path)
    if data_list is None:
        # read_json_fields logs the cause and returns None on failure. Stop here
        # instead of opening an API client and failing later on iterating None.
        raise ValueError(f"Could not load input data from {input_path}")
    is_cot_model = "cot" in config["job_type"]
    generate_teacher_response_api(data_list, config, is_cot_model)

        
def main():
    """Command-line entry point: read the JSON config named by ``--config`` and run scoring."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True, help='path to the json config file')
    args = parser.parse_args()
    # Use a context manager so the config file handle is closed promptly.
    with open(args.config, 'r', encoding='utf-8') as config_file:
        config = json.load(config_file)
    infer_with_teacher_model(config)


if __name__ == "__main__":
    main()